{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModel,\n",
    "    AutoTokenizer,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    HfArgumentParser,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from arguments import ModelArguments\n",
    "import argparse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments read by the following cells: where to load the base model and the\n",
    "# prefix-encoder checkpoint from, plus pre_seq_len and the quantization bit width.\n",
    "model_args = ModelArguments(\n",
    "    model_name_or_path=\"pretrain_model\",\n",
    "    ptuning_checkpoint=\"./ckpt/ptuningv2\",\n",
    "    pre_seq_len=128,\n",
    "    quantization_bit=8,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n",
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "# Configuration of the base model, extended with the prefix settings from model_args.\n",
    "config = AutoConfig.from_pretrained(\n",
    "    model_args.model_name_or_path,\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "config.pre_seq_len = model_args.pre_seq_len\n",
    "config.prefix_projection = model_args.prefix_projection\n",
    "\n",
    "# Tokenizer loaded from the same model path as the configuration.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_args.model_name_or_path,\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n",
      "Loading checkpoint shards: 100%|██████████| 8/8 [00:10<00:00,  1.32s/it]\n",
      "Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at pretrain_model and are newly initialized: ['transformer.prefix_encoder.embedding.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the base model with the prefix-aware config prepared above.\n",
    "model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)\n",
    "\n",
    "# Load the P-tuning v2 checkpoint onto the CPU: the model is still on the CPU here\n",
    "# (it is moved to the GPU in a later cell), and without map_location torch.load\n",
    "# restores every tensor to the device it was saved from.\n",
    "# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.\n",
    "PREFIX_ENCODER_KEY = \"transformer.prefix_encoder.\"\n",
    "prefix_state_dict = torch.load(\n",
    "    os.path.join(model_args.ptuning_checkpoint, \"pytorch_model.bin\"),\n",
    "    map_location=\"cpu\",\n",
    ")\n",
    "\n",
    "# Keep only the prefix-encoder entries and strip the leading module path so the\n",
    "# keys match the state dict of model.transformer.prefix_encoder itself.\n",
    "new_prefix_state_dict = {\n",
    "    key[len(PREFIX_ENCODER_KEY):]: tensor\n",
    "    for key, tensor in prefix_state_dict.items()\n",
    "    if key.startswith(PREFIX_ENCODER_KEY)\n",
    "}\n",
    "model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Quantized to 8 bit\n"
     ]
    }
   ],
   "source": [
    "# Quantize only when a bit width is configured, and report it after the call so\n",
    "# the message is printed only if quantization actually completed.\n",
    "# NOTE(review): assumes an unset quantization_bit is None -- confirm against ModelArguments.\n",
    "if model_args.quantization_bit is not None:\n",
    "    model = model.quantize(model_args.quantization_bit)\n",
    "    print(f\"Quantized to {model_args.quantization_bit} bit\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the model to half precision but keep the prefix encoder in float32,\n",
    "# then move everything to the GPU and switch to evaluation mode.\n",
    "model = model.half()\n",
    "model.transformer.prefix_encoder.float()\n",
    "model = model.cuda().eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你好，这种情况应该是咽喉炎症引起的，可以口服阿奇霉素，利咽颗粒试试，保持室内空气流通，平时饮食中加点水果蔬菜，另外可以口服蓝芩口服液调理一下。慢慢会改善的，祝健康\n"
     ]
    }
   ],
   "source": [
    "# Sampling is enabled below, so seed the random number generators first; this\n",
    "# makes the reply repeatable when the notebook is re-run on the same setup.\n",
    "SEED = 42\n",
    "set_seed(SEED)\n",
    "\n",
    "# Ask a single question with an empty conversation history and print the reply.\n",
    "question = \"咽痛，咳嗽，多痰\"\n",
    "response, history = model.chat(\n",
    "    tokenizer,\n",
    "    question,\n",
    "    history=[],\n",
    "    max_length=2048,\n",
    "    eos_token_id=config.eos_token_id,\n",
    "    do_sample=True,\n",
    "    top_p=0.7,\n",
    "    temperature=1,\n",
    ")\n",
    "print(response)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "origin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
